{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cffb1aa4-7225-424c-8a47-86d8ec91672b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the umap-learn package (imported as `umap` in the next cell) using uv.\n",
    "!uv pip install umap-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9d655f8-8bef-4da5-a59b-77a744f35bc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from concurrent.futures import ThreadPoolExecutor\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "import seaborn as sns\n",
    "import umap\n",
    "from tqdm import tqdm\n",
    "\n",
    "#import dask.dataframe as dd\n",
    "#from dask import delayed, compute\n",
    "#from dask.diagnostics import ProgressBar\n",
    "# from tqdm.auto import tqdm  # for notebooks\n",
    "\n",
    "# Create new `pandas` methods which use `tqdm` progress\n",
    "# (can use tqdm_gui, optional kwargs, etc.)\n",
    "tqdm.pandas()\n",
    "\n",
    "def load_features(fn, feature_dir='/data/ECG_AF/encoder_9_features_it2_val'):\n",
    "    \"\"\"Load one saved feature array and reduce it to a single vector.\n",
    "\n",
    "    Args:\n",
    "        fn: File name of the stored array, relative to ``feature_dir``.\n",
    "        feature_dir: Directory holding the feature files. Defaults to the\n",
    "            directory this notebook was written against.\n",
    "\n",
    "    Returns:\n",
    "        The mean of the stored array over axis 0, rounded to float16 and\n",
    "        converted to plain Python floats via ``tolist()``.\n",
    "    \"\"\"\n",
    "    features = np.load(f'{feature_dir}/{fn}')\n",
    "    # Average over axis 0 (presumably the time/token axis -- TODO confirm), then\n",
    "    # round to float16 to keep the collected vectors small.\n",
    "    return features.mean(axis=0).astype(np.float16).tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32d83005-3609-449d-ad4d-53437a0f032a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the CSV at this path into `df`; later cells use its `filename` and `dataset` columns.\n",
    "df = pd.read_csv('/data/ECG_AF/val_self_supervised_processed.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b25256be-0f98-4965-a94f-a251d4eaee04",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work on the full table; append .sample(50000) here to subsample instead.\n",
    "df_select = df\n",
    "# Serial alternative for loading the features (kept for reference):\n",
    "# fv = df_select.filename.progress_apply(load_features).values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c556b1c-34e6-40e8-8ee1-36de64418cc5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Load every feature file in parallel. Each load is independent and presumably\n",
    "# I/O-bound, so a thread pool is sufficient. `map` yields results in input order,\n",
    "# which keeps `fv` aligned row-for-row with `df_select`.\n",
    "N_WORKERS = 32\n",
    "\n",
    "filenames = df_select['filename'].tolist()\n",
    "with ThreadPoolExecutor(max_workers=N_WORKERS) as pool:\n",
    "    fv = list(tqdm(pool.map(load_features, filenames), total=len(filenames), desc='loading features'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a879460-0759-4f7f-a00d-c6aff5650503",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the selected rows; the cells after the divider reload them from this file.\n",
    "df_select.to_parquet('df_select.parquet')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36ee3ac3-2881-48e7-88de-19a19c304940",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-file feature lists into one array (assumes every vector has the same length).\n",
    "fv_arr = np.array(list(fv))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b357c5c2-42d4-4ce6-80d5-5895091e4e1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the dtype of the stacked feature array.\n",
    "fv_arr.dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "512f1307-2f9d-43fe-9771-2a5ebfde0743",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cache the features on disk as float16; they are reloaded from 'fv.npy' further down.\n",
    "np.save(\"fv.npy\", fv_arr.astype(np.float16))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "630f0a0b-a389-4697-96ea-4a31a6a03dab",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42e4f22c-f35a-41ad-b1d2-89a20bc46bc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the rows cached in 'df_select.parquet'.\n",
    "df_select = pd.read_parquet('df_select.parquet')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91b934a1-c8ed-4a15-9413-682eaa207c9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fold the CPSC-EXTRA rows into CPSC so both are treated as one dataset.\n",
    "is_cpsc_extra = df_select['dataset'] == 'CPSC-EXTRA'\n",
    "df_select.loc[is_cpsc_extra, 'dataset'] = 'CPSC'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db1fe353-283d-4057-a15e-e153e2c92d5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the feature array cached in 'fv.npy'.\n",
    "fv_arr = np.load('fv.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b875aab6-07a9-4f22-b45b-910fb58f1be9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the seed so the embedding, and every figure derived from it, is reproducible.\n",
    "# Note: umap-learn disables multi-threading when random_state is set, so this is\n",
    "# slower than an unseeded fit.\n",
    "reducer = umap.UMAP(random_state=42)\n",
    "embedding = reducer.fit_transform(fv_arr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19f0195c-a203-4212-a1f7-80bd1329cdae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed dataset ordering; each name's position in the list is its integer code.\n",
    "ds_order = [\n",
    "    'CHAPMAN', 'CPSC', 'CPSC-EXTRA', 'GEORGIA',\n",
    "    'HEFEI', 'Mimic', 'NINGBO', 'PTB',\n",
    "    'PTB-XL', 'RIBEIRO', 'Samitrop',\n",
    "]\n",
    "mapping = {name: position for position, name in enumerate(ds_order)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e606d4f-a61e-4bc4-b078-55a3aacb0867",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the dataset-name to index mapping.\n",
    "mapping"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51aa91b5-fe86-493f-9d1b-16d59afe11bf",
   "metadata": {},
   "source": [
    "'Mimic', 'Samitrop', 'Georgia', 'Ptb', 'Ningbo', 'Ribeiro', 'Hefei','Cpsc', 'Sph', 'Ptbxl'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f658c89a-b3d0-4260-8f29-1d30c52478eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One viridis colour per entry in ds_order.\n",
    "colors = sns.color_palette(\"viridis\", n_colors=len(ds_order))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f7e0e10-8805-4b85-bf33-757bd9292e69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Translate dataset names to palette indices; fail loudly on names that are\n",
    "# missing from `mapping` instead of hitting an opaque indexing error below.\n",
    "dataset_codes = df_select['dataset'].map(mapping)\n",
    "if dataset_codes.isna().any():\n",
    "    unknown = sorted(df_select.loc[dataset_codes.isna(), 'dataset'].astype(str).unique())\n",
    "    raise ValueError(f'datasets missing from ds_order: {unknown}')\n",
    "dataset_codes = dataset_codes.astype(int)\n",
    "\n",
    "mpl_fig, ax = plt.subplots(figsize=(8, 10))\n",
    "ax.scatter(\n",
    "    embedding[:, 0],\n",
    "    embedding[:, 1],\n",
    "    c=[colors[code] for code in dataset_codes],\n",
    "    s=20,\n",
    ")\n",
    "# Empty proxy scatters give one legend entry per dataset that is present,\n",
    "# without redrawing (and thereby re-layering) the real points.\n",
    "for ds_name in ds_order:\n",
    "    if (df_select['dataset'] == ds_name).any():\n",
    "        ax.scatter([], [], color=colors[mapping[ds_name]], s=20, label=ds_name)\n",
    "# A fixed location avoids the slow 'best' placement search on large scatters.\n",
    "ax.legend(title='Dataset', loc='upper right')\n",
    "ax.set_xlabel('UMAP 1')\n",
    "ax.set_ylabel('UMAP 2')\n",
    "#ax.set_aspect('equal', 'datalim')\n",
    "ax.set_title('UMAP projection', fontsize=20)\n",
    "mpl_fig.savefig('umap_projection.png', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2218bd39-7608-4a61-9701-cc85ee80db53",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_select['embedding0'] = embedding[:, 0]\n",
    "df_select['embedding1'] = embedding[:, 1]\n",
    "\n",
    "# Pass the frame and refer to columns by name so that the keys of `labels` and\n",
    "# `category_orders` ('dataset') point at the same column and the legend follows\n",
    "# ds_order.\n",
    "fig = px.scatter(\n",
    "    df_select,\n",
    "    x='embedding0',\n",
    "    y='embedding1',\n",
    "    color='dataset',\n",
    "    labels={'dataset': 'Dataset', 'embedding0': 'UMAP 1', 'embedding1': 'UMAP 2'},\n",
    "    category_orders={'dataset': ds_order},\n",
    "    color_discrete_sequence=px.colors.sequential.Viridis,\n",
    ")\n",
    "#fig.write_html(\"umap.html\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07c61780-0faa-4be2-a4e5-9143d8ea6fcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the Plotly figure built in the previous cell.\n",
    "fig.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
